
Conversation

@ground0state
Contributor

The timm.data.Mixup class does not work when the device is CPU.

import torch
from timm.data.mixup import Mixup

mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 0.,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 4
}

mixup_fn = Mixup(**mixup_args)

# train_dataset is any dataset that yields CPU tensors
loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
x, labels = next(iter(loader))
x, labels = mixup_fn(x, labels)  # raises RuntimeError on CPU
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-22-ac462bac9f44> in <module>()
     21 loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
     22 x, labels = next(iter(loader))
---> 23 x, labels = mixup_fn(x, labels)
     24 
     25 

/usr/local/lib/python3.7/dist-packages/timm/data/mixup.py in __call__(self, x, target)
    215         else:
    216             lam = self._mix_batch(x)
--> 217         target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
    218         return x, target
    219 

/usr/local/lib/python3.7/dist-packages/timm/data/mixup.py in mixup_target(target, num_classes, lam, smoothing, device)
     23     off_value = smoothing / num_classes
     24     on_value = 1. - smoothing + off_value
---> 25     y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
     26     y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
     27     return y1 * lam + y2 * (1. - lam)

/usr/local/lib/python3.7/dist-packages/timm/data/mixup.py in one_hot(x, num_classes, on_value, off_value, device)
     17 def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
     18     x = x.long().view(-1, 1)
---> 19     return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
     20 
     21 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument index in method wrapper_scatter__value)

So I fixed the timm.data.Mixup class so that it also works when the tensors are on the CPU.
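
The idea behind the fix (a minimal sketch here; the merged commit may differ in detail) is to stop defaulting one_hot to device='cuda' and instead derive the device from the label tensor, so mixup_target builds the one-hot targets on whatever device the labels already live on:

import torch

def one_hot(x, num_classes, on_value=1., off_value=0., device=None):
    # Use the device of the label tensor x unless one is passed explicitly,
    # instead of hard-coding device='cuda'.
    device = x.device if device is None else device
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

With a change along these lines, the reproduction snippet above runs unchanged on the CPU, and GPU training keeps working because CUDA label tensors resolve to their own device.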

@rwightman rwightman merged commit 2c33ca6 into huggingface:master Oct 12, 2021
guoriyue pushed a commit to guoriyue/pytorch-image-models that referenced this pull request May 24, 2024
Fix bugs that Mixup does not work when device is cpu
